import random
from typing import Optional


class ListNode:
    def __init__(self, data: Optional[int] = None):
        # 链表节点的数据与，目的是方便创建头节点
        self._data = data
        # 存储各个索引级中该节点的后驱索引节点
        self._forwards = []


class Skiplist:
    # 设置调表最大层级
    _MAX_LEVEL = 16

    def __init__(self):
        # 初始化层级1
        self._level_count = 1
        self._head = ListNode()
        self._head._forwards = [None] * self._MAX_LEVEL

    def search(self, target: int) -> bool:
        p = self._head
        # 从最高级开始搜索,如果当前层级没有,则下沉到低一层级
        for i in range(self._level_count - 1, -1, -1):
            while p._forwards[i] and p._forwards[i]._data < target:
                p = p._forwards[i]

        if p._forwards[0] and p._forwards[0]._data == target:
            return True
        return False

    def add(self, num: int) -> None:
        # 随机生成索引层级
        level = self._random_level()
        # 如果当前最大层级 小于 随机到的层级,就更新最大层级
        if self._level_count < level:  # 如果当前层级小于  level,则更新当前最高层级
            self._level_count = level
        # 生成新节点
        new_node = ListNode(num)
        # 新节点的后驱节点
        new_node._forwards = [None] * level
        # 用来保存各个索引层级插入的位置,新节点的前驱节点
        update = [self._head] * self._level_count

        p = self._head
        # 用来获取新插入节点在各个索引层级的前驱节点,从最高级开始循环
        for i in range(self._level_count - 1, -1, -1):
            while p._forwards[i] and p._forwards[i]._data < num:
                p = p._forwards[i]

            update[i] = p

        # 更新需要更新的各个索引层级,i代表层级
        for i in range(level):
            new_node._forwards[i] = update[i]._forwards[i]
            update[i]._forwards[i] = new_node

    # 抹去
    def erase(self, num: int) -> bool:
        update = [None] * self._level_count
        p = self._head
        # 获取前置索引
        for i in range(self._level_count - 1, -1, -1):
            while p._forwards[i] and p._forwards[i]._data < num:
                p = p._forwards[i]
            update[i] = p
        # 如果是第0层索引的 0 个元素
        if p._forwards[0] and p._forwards[0]._data == num:
            for i in range(self._level_count - 1, -1, -1):
                if update[i]._forwards[i] and update[i]._forwards[i]._data == num:
                    update[i]._forwards[i] = update[i]._forwards[i]._forwards[i]
            return True
        while self._level_count > 1 and not self._head._forwards[self._level_count]:
            self._level_count -= 1
        return False

    def _random_level(self, p: float = 0.5) -> int:
        level = 1
        # 通过不断地 * 1/2 来叠加概率
        while random.random() < p and level < self._MAX_LEVEL:
            level += 1
        return level


skiplist = Skiplist()
skiplist.add(1)
skiplist.add(2)
skiplist.add(3)
skiplist.search(0)
skiplist.add(4)
skiplist.search(1)
skiplist.erase(0)
skiplist.erase(1)
skiplist.search(1)
